{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7765UFHoyGx6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RKQpW0JqQQmY"
      },
      "source": [
        "# TensorFlow Lattice を使った形状制約\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/lattice/tutorials/shape_constraints\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"> TensorFlow.orgで表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/lattice/tutorials/shape_constraints.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\"> Google Colab で実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/lattice/tutorials/shape_constraints.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示{</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/lattice/tutorials/shape_constraints.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード/a0}</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2plcL3iTVjsp"
      },
      "source": [
        "## 概要\n",
        "\n",
        "このチュートリアルは、TensorFlow Lattice（TFL）ライブラリが提供する制約と正規化の概要です。ここでは、合成データセットに TFL 缶詰 Estimator を使用しますが、このチュートリアルの内容は TFL Keras レイヤーから構築されたモデルでも実行できます。\n",
        "\n",
        "続行する前に、ランタイムに必要なすべてのパッケージがインストールされていることを確認してください（以下のコードセルでインポートされるとおりに行います）。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x769lI12IZXB"
      },
      "source": [
        "## セットアップ"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbBVAR6UeRN5"
      },
      "source": [
        "TF Lattice パッケージをインストールします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpXjJKpSd3j4"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install tensorflow-lattice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jSVl9SHTeSGX"
      },
      "source": [
        "必要なパッケージをインポートします。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "iY6awAl058TV"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "from IPython.core.pylabtools import figsize\n",
        "import itertools\n",
        "import logging\n",
        "import matplotlib\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7TmBk_IGgJF0"
      },
      "source": [
        "このガイドで使用されるデフォルト値です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kQHPyPsPUF92"
      },
      "outputs": [],
      "source": [
        "NUM_EPOCHS = 500\n",
        "BATCH_SIZE = 64\n",
        "LEARNING_RATE=0.001"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FjR7D8Ag3z0d"
      },
      "source": [
        "## レストランのランク付けに使用するトレーニングデータセット"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a1YetzbdFOij"
      },
      "source": [
        "ユーザーがレストランの検索結果をクリックするかどうかを判定する、単純なシナリオを想定しましょう。このタスクでは、次の特定の入力特徴量でクリック率（CTR）を予測します。\n",
        "\n",
        "- 平均評価（`avg_rating`）: [1,5] の範囲の値による数値特徴量。\n",
        "- レビュー数（`num_reviews`）: 最大値 200 の数値特徴量。流行状況の測定値として使用します。\n",
        "- ドル記号評価（`dollar_rating`）: {\"D\", \"DD\", \"DDD\", \"DDDD\"} セットの文字列値による分類特徴量。\n",
        "\n",
        "ここでは、真の CTR を式 $$ CTR = 1 / (1 + exp{\\mbox{b(dollar_rating)}-\\mbox{avg_rating}\\times log(\\mbox{num_reviews}) /4 }) $$ で得る合成データセットを作成します。$b(\\cdot)$ は各 `dollar_rating` をベースラインの値 $$ \\mbox{D}\\to 3,\\ \\mbox{DD}\\to 2,\\ \\mbox{DDD}\\to 4,\\ \\mbox{DDDD}\\to 4.5. $$ に変換します。\n",
        "\n",
        "この式は、典型的なユーザーパターンを反映します。たとえば、ほかのすべてが固定された状態で、ユーザーは星評価の高いレストランを好み、\"\\$\\$\" のレストランは \"\\$\" のレストランよりも多いクリック率を得、\"\\$\\$\\$\"、\"\\$\\$\\$\\$\" となればさらに多いクリック率を得るというパターンです。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mKovnyv1jATw"
      },
      "outputs": [],
      "source": [
        "def click_through_rate(avg_ratings, num_reviews, dollar_ratings):\n",
        "  dollar_rating_baseline = {\"D\": 3, \"DD\": 2, \"DDD\": 4, \"DDDD\": 4.5}\n",
        "  return 1 / (1 + np.exp(\n",
        "      np.array([dollar_rating_baseline[d] for d in dollar_ratings]) -\n",
        "      avg_ratings * np.log1p(num_reviews) / 4))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BPlgRdt6jAbP"
      },
      "source": [
        "この CTR 関数の等高線図を見てみましょう。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KC5qX_XKmc7g"
      },
      "outputs": [],
      "source": [
        "def color_bar():\n",
        "  bar = matplotlib.cm.ScalarMappable(\n",
        "      norm=matplotlib.colors.Normalize(0, 1, True),\n",
        "      cmap=\"viridis\",\n",
        "  )\n",
        "  bar.set_array([0, 1])\n",
        "  return bar\n",
        "\n",
        "\n",
        "def plot_fns(fns, split_by_dollar=False, res=25):\n",
        "  \"\"\"Generates contour plots for a list of (name, fn) functions.\"\"\"\n",
        "  num_reviews, avg_ratings = np.meshgrid(\n",
        "      np.linspace(0, 200, num=res),\n",
        "      np.linspace(1, 5, num=res),\n",
        "  )\n",
        "  if split_by_dollar:\n",
        "    dollar_rating_splits = [\"D\", \"DD\", \"DDD\", \"DDDD\"]\n",
        "  else:\n",
        "    dollar_rating_splits = [None]\n",
        "  if len(fns) == 1:\n",
        "    fig, axes = plt.subplots(2, 2, sharey=True, tight_layout=False)\n",
        "  else:\n",
        "    fig, axes = plt.subplots(\n",
        "        len(dollar_rating_splits), len(fns), sharey=True, tight_layout=False)\n",
        "  axes = axes.flatten()\n",
        "  axes_index = 0\n",
        "  for dollar_rating_split in dollar_rating_splits:\n",
        "    for title, fn in fns:\n",
        "      if dollar_rating_split is not None:\n",
        "        dollar_ratings = np.repeat(dollar_rating_split, res**2)\n",
        "        values = fn(avg_ratings.flatten(), num_reviews.flatten(),\n",
        "                    dollar_ratings)\n",
        "        title = \"{}: dollar_rating={}\".format(title, dollar_rating_split)\n",
        "      else:\n",
        "        values = fn(avg_ratings.flatten(), num_reviews.flatten())\n",
        "      subplot = axes[axes_index]\n",
        "      axes_index += 1\n",
        "      subplot.contourf(\n",
        "          avg_ratings,\n",
        "          num_reviews,\n",
        "          np.reshape(values, (res, res)),\n",
        "          vmin=0,\n",
        "          vmax=1)\n",
        "      subplot.title.set_text(title)\n",
        "      subplot.set(xlabel=\"Average Rating\")\n",
        "      subplot.set(ylabel=\"Number of Reviews\")\n",
        "      subplot.set(xlim=(1, 5))\n",
        "\n",
        "  _ = fig.colorbar(color_bar(), cax=fig.add_axes([0.95, 0.2, 0.01, 0.6]))\n",
        "\n",
        "\n",
        "figsize(11, 11)\n",
        "plot_fns([(\"CTR\", click_through_rate)], split_by_dollar=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ol91olp3muNN"
      },
      "source": [
        "### データを準備する\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H8BOshZS9xwn"
      },
      "source": [
        "次に、合成データセットを作成する必要があります。シミュレーション済みのレストランのデータセットとその特徴量を生成するところから始めます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MhqcOPdTT_wj"
      },
      "outputs": [],
      "source": [
        "def sample_restaurants(n):\n",
        "  avg_ratings = np.random.uniform(1.0, 5.0, n)\n",
        "  num_reviews = np.round(np.exp(np.random.uniform(0.0, np.log(200), n)))\n",
        "  dollar_ratings = np.random.choice([\"D\", \"DD\", \"DDD\", \"DDDD\"], n)\n",
        "  ctr_labels = click_through_rate(avg_ratings, num_reviews, dollar_ratings)\n",
        "  return avg_ratings, num_reviews, dollar_ratings, ctr_labels\n",
        "\n",
        "\n",
        "np.random.seed(42)\n",
        "avg_ratings, num_reviews, dollar_ratings, ctr_labels = sample_restaurants(2000)\n",
        "\n",
        "figsize(5, 5)\n",
        "fig, axs = plt.subplots(1, 1, sharey=False, tight_layout=False)\n",
        "for rating, marker in [(\"D\", \"o\"), (\"DD\", \"^\"), (\"DDD\", \"+\"), (\"DDDD\", \"x\")]:\n",
        "  plt.scatter(\n",
        "      x=avg_ratings[np.where(dollar_ratings == rating)],\n",
        "      y=num_reviews[np.where(dollar_ratings == rating)],\n",
        "      c=ctr_labels[np.where(dollar_ratings == rating)],\n",
        "      vmin=0,\n",
        "      vmax=1,\n",
        "      marker=marker,\n",
        "      label=rating)\n",
        "plt.xlabel(\"Average Rating\")\n",
        "plt.ylabel(\"Number of Reviews\")\n",
        "plt.legend()\n",
        "plt.xlim((1, 5))\n",
        "plt.title(\"Distribution of restaurants\")\n",
        "_ = fig.colorbar(color_bar(), cax=fig.add_axes([0.95, 0.2, 0.01, 0.6]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tRetsfLv_JSR"
      },
      "source": [
        "トレーニング、評価、およびテストデータセットを生成しましょう。検索結果でレストランが閲覧されるときに、ユーザーのエンゲージメント（クリック有りまたはクリック無し）をサンプルポイントとして記録できます。\n",
        "\n",
        "実際には、ユーザーが全検索結果を見ることはほとんどありません。つまり、ユーザーは、使用されている現在のランキングモデルによってすでに「良い」とみなされているレストランのみを閲覧する傾向にあるでしょう。そのため、トレーニングデータセットでは「良い」レストランはより頻繁に表示されて、過剰表現されます。さらに多くの特徴量を使用する際に、トレーニングデータセットでは、特徴量空間の「悪い」部分に大きなギャップが生じてしまいます。\n",
        "\n",
        "モデルがランキングに使用される場合、トレーニングデータセットで十分に表現されていないより均一な分布を持つ、すべての関連結果で評価されることがほとんどです。この場合、過剰に表現されたデータポイントの過適合によって一般化可能性に欠けることから、柔軟で複雑なモデルは失敗する可能性があります。この問題には、トレーニングデータセットから形状制約を拾えない場合に合理的な予測を立てられるようにモデルを誘導する*形状制約*を追加するドメインナレッジを適用して対処します。\n",
        "\n",
        "この例では、トレーニングデータセットは、人気のある良いレストランとのユーザーインタラクションで構成されており、テストデータセットには、上記で説明した評価設定をシミュレーションする一様分布があります。このようなテストデータセットは、実際の問題設定では利用できないことに注意してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jS6WOtXQ8jwX"
      },
      "outputs": [],
      "source": [
        "def sample_dataset(n, testing_set):\n",
        "  (avg_ratings, num_reviews, dollar_ratings, ctr_labels) = sample_restaurants(n)\n",
        "  if testing_set:\n",
        "    # Testing has a more uniform distribution over all restaurants.\n",
        "    num_views = np.random.poisson(lam=3, size=n)\n",
        "  else:\n",
        "    # Training/validation datasets have more views on popular restaurants.\n",
        "    num_views = np.random.poisson(lam=ctr_labels * num_reviews / 50.0, size=n)\n",
        "\n",
        "  return pd.DataFrame({\n",
        "      \"avg_rating\": np.repeat(avg_ratings, num_views),\n",
        "      \"num_reviews\": np.repeat(num_reviews, num_views),\n",
        "      \"dollar_rating\": np.repeat(dollar_ratings, num_views),\n",
        "      \"clicked\": np.random.binomial(n=1, p=np.repeat(ctr_labels, num_views))\n",
        "  })\n",
        "\n",
        "\n",
        "# Generate datasets.\n",
        "np.random.seed(42)\n",
        "data_train = sample_dataset(2000, testing_set=False)\n",
        "data_val = sample_dataset(1000, testing_set=False)\n",
        "data_test = sample_dataset(1000, testing_set=True)\n",
        "\n",
        "# Plotting dataset densities.\n",
        "figsize(12, 5)\n",
        "fig, axs = plt.subplots(1, 2, sharey=False, tight_layout=False)\n",
        "for ax, data, title in [(axs[0], data_train, \"training\"),\n",
        "                        (axs[1], data_test, \"testing\")]:\n",
        "  _, _, _, density = ax.hist2d(\n",
        "      x=data[\"avg_rating\"],\n",
        "      y=data[\"num_reviews\"],\n",
        "      bins=(np.linspace(1, 5, num=21), np.linspace(0, 200, num=21)),\n",
        "      density=True,\n",
        "      cmap=\"Blues\",\n",
        "  )\n",
        "  ax.set(xlim=(1, 5))\n",
        "  ax.set(ylim=(0, 200))\n",
        "  ax.set(xlabel=\"Average Rating\")\n",
        "  ax.set(ylabel=\"Number of Reviews\")\n",
        "  ax.title.set_text(\"Density of {} examples\".format(title))\n",
        "  _ = fig.colorbar(density, ax=ax)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4fVyLgpCT1nW"
      },
      "source": [
        "トレーニングと評価に使用する input_fn を定義します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DYzRTRR2GKoS"
      },
      "outputs": [],
      "source": [
        "train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=data_train,\n",
        "    y=data_train[\"clicked\"],\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=NUM_EPOCHS,\n",
        "    shuffle=False,\n",
        ")\n",
        "\n",
        "# feature_analysis_input_fn is used for TF Lattice estimators.\n",
        "feature_analysis_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=data_train,\n",
        "    y=data_train[\"clicked\"],\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    shuffle=False,\n",
        ")\n",
        "\n",
        "val_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=data_val,\n",
        "    y=data_val[\"clicked\"],\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    shuffle=False,\n",
        ")\n",
        "\n",
        "test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=data_test,\n",
        "    y=data_test[\"clicked\"],\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    shuffle=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qoTrw3FZqvPK"
      },
      "source": [
        "## 勾配ブースティング木を適合させる"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZklNowexE3wB"
      },
      "source": [
        "まずは、`avg_rating` と `num_reviews` の 2 つの特徴量から始めましょう。\n",
        "\n",
        "検証とテストのメトリックを描画および計算する補助関数をいくつか作成します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SX6rARJWURWl"
      },
      "outputs": [],
      "source": [
        "def analyze_two_d_estimator(estimator, name):\n",
        "  # Extract validation metrics.\n",
        "  metric = estimator.evaluate(input_fn=val_input_fn)\n",
        "  print(\"Validation AUC: {}\".format(metric[\"auc\"]))\n",
        "  metric = estimator.evaluate(input_fn=test_input_fn)\n",
        "  print(\"Testing AUC: {}\".format(metric[\"auc\"]))\n",
        "\n",
        "  def two_d_pred(avg_ratings, num_reviews):\n",
        "    results = estimator.predict(\n",
        "        tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "            x=pd.DataFrame({\n",
        "                \"avg_rating\": avg_ratings,\n",
        "                \"num_reviews\": num_reviews,\n",
        "            }),\n",
        "            shuffle=False,\n",
        "        ))\n",
        "    return [x[\"logistic\"][0] for x in results]\n",
        "\n",
        "  def two_d_click_through_rate(avg_ratings, num_reviews):\n",
        "    return np.mean([\n",
        "        click_through_rate(avg_ratings, num_reviews,\n",
        "                           np.repeat(d, len(avg_ratings)))\n",
        "        for d in [\"D\", \"DD\", \"DDD\", \"DDDD\"]\n",
        "    ],\n",
        "                   axis=0)\n",
        "\n",
        "  figsize(11, 5)\n",
        "  plot_fns([(\"{} Estimated CTR\".format(name), two_d_pred),\n",
        "            (\"CTR\", two_d_click_through_rate)],\n",
        "           split_by_dollar=False)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JVef4f8yUUbs"
      },
      "source": [
        "データセットに TensorFlow 勾配ブースティング決定木を適合できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DnPYlRAo2mnQ"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "gbt_estimator = tf.estimator.BoostedTreesClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    # Hyper-params optimized on validation set.\n",
        "    n_batches_per_layer=1,\n",
        "    max_depth=3,\n",
        "    n_trees=20,\n",
        "    min_node_weight=0.1,\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "gbt_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(gbt_estimator, \"GBT\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nYZtd6YvsNdn"
      },
      "source": [
        "モデルは本来の CTR の一般的な形状をキャプチャし、まともな検証メトリックを使用していますが、入力空間のいくつかの部分に直感に反する振る舞いがあります。推定される CTR は平均評価またはレビュー数が増加するにつれ降下しているところです。これは、トレーニングデータセットがうまくカバーしていない領域のサンプルポイントが不足しているためです。単に、モデルにはデータのみから正しい振る舞いを推測する術がないのです。\n",
        "\n",
        "この問題を解決するには、モデルが平均評価とレビュー数の両方に対して単調的に増加する値を出力しなければならないように、形状制約を強制します。TFL にこれを実装する方法は、後で説明します。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uf7WqGooFiEp"
      },
      "source": [
        "## DNN を適合させる"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_s2aT3x0E_tF"
      },
      "source": [
        "DNN 分類器で、同じ手順を繰り返すことができます。レビュー数が少なく、十分なサンプルポイントがないため、同様の、意味をなさない外挿パターンとなります。検証メトリックが木のソリューションより優れていても、テストメトリックが悪化するところに注意してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gFUeG6kLDNhO"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "dnn_estimator = tf.estimator.DNNClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    # Hyper-params optimized on validation set.\n",
        "    hidden_units=[16, 8, 8],\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "dnn_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(dnn_estimator, \"DNN\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Avkw-okw7JL"
      },
      "source": [
        "## 形状制約"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3ExyethCFBrP"
      },
      "source": [
        "TensorFlow Lattice（TFL）の焦点は、トレーニングデータを超えてモデルの振る舞いを守るために形状制約を強制することに当てられます。形状制約は TFL Keras レイヤーに適用されます。その詳細は、[TensorFlow の JMLR 論文](http://jmlr.org/papers/volume17/15-243/15-243.pdf)をご覧ください。\n",
        "\n",
        "このチュートリアルでは、TF 缶詰 Estimator を使用してさまざまな形状制約を説明しますが、手順はすべて、TFL Keras レイヤーから作成されたモデルで実行することができます。\n",
        "\n",
        "ほかの TensorFlow Estimator と同様に、TFL 缶詰 Estimator では、[特徴量カラム](https://www.tensorflow.org/api_docs/python/tf/feature_column)を使用して入力形式を定義し、トレーニングの input_fn を使用してデータを渡します。TFL 缶詰 Estimator を使用するには、次の項目も必要です。\n",
        "\n",
        "- *モデルの構成*: モデルのアーキテクチャと特徴量ごとの形状制約とレギュラライザを定義します。\n",
        "- *特徴量分析 input_fn*: TFL 初期化を行うために TF input_fn でデータを渡します。\n",
        "\n",
        "より詳しい説明については、缶詰 Estimator のチュートリアルまたは API ドキュメントをご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "anyCM4sCpOSo"
      },
      "source": [
        "### 単調性\n",
        "\n",
        "最初に、単調性形状制約を両方の特徴量に追加して、単調性に関する問題を解決します。\n",
        "\n",
        "TFL に形状制約を強制するように指示するには、*特徴量の構成*に制約を指定します。次のコードは、`monotonicity=\"increasing\"` を設定することによって、`num_reviews` と `avg_rating` の両方に対して単調的に出力を増加するようにする方法を示します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FCm1lOjmwur_"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        )\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(tfl_estimator, \"TF Lattice\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ubNRBCWW5wQ9"
      },
      "source": [
        "`CalibratedLatticeConfig` を使用して、各入力に*キャリブレータ*を適用（数値特徴量のピース単位の線形関数）してから*格子*レイヤーを適用して非線形的に較正済みの特徴量を融合する缶詰分類器を作成します。モデルの視覚化には、`tfl.visualization` を使用できます。特に、次のプロットは、缶詰分類器に含まれるトレーニング済みのキャリブレータを示します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "C0py9Q6OBRBE"
      },
      "outputs": [],
      "source": [
        "def save_and_visualize_lattice(tfl_estimator):\n",
        "  saved_model_path = tfl_estimator.export_saved_model(\n",
        "      \"/tmp/TensorFlow_Lattice_101/\",\n",
        "      tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
        "          feature_spec=tf.feature_column.make_parse_example_spec(\n",
        "              feature_columns)))\n",
        "  model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "  figsize(8, 8)\n",
        "  tfl.visualization.draw_model_graph(model_graph)\n",
        "  return model_graph\n",
        "\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7vZ5fShXs504"
      },
      "source": [
        "制約が追加されると、推定される CTR は平均評価またはレビュー数が増加するにつれて、必ず増加するようになります。これは、キャリブレータと格子を確実に単調にすることで行われます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RfniRZCHIvfK"
      },
      "source": [
        "### 収穫逓減\n",
        "\n",
        "[収穫逓減](https://en.wikipedia.org/wiki/Diminishing_returns)とは、特定の特徴量値を増加すると、それを高める上で得る限界利益は減少することを意味します。このケースでは、`num_reviews` 特徴慮はこのパターンに従うと予測されるため、それに合わせてキャリブレータを構成することができます。収穫逓減を次の 2 つの十分な条件に分けることができます。\n",
        "\n",
        "- キャリブレータが単調的に増加している\n",
        "- キャリブレータが凹状である\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XQrM9BskY-wx"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        )\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LSmzHkPUo9u5"
      },
      "source": [
        "テストメトリックが、凹状の制約を追加することで改善しているのがわかります。予測図もグラウンドトゥルースにより似通っています。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J6CP2Ovapiu3"
      },
      "source": [
        "### 2D 形状制約: 信頼\n",
        "\n",
        "1 つか 2 つのレビューのみを持つレストランの 5 つ星評価は、信頼できない評価である可能性があります（レストランは実際には良くない可能性があります）が、数百件のレビューのあるレストランの 4 つ星評価にははるかに高い信頼性があります（この場合、レストランは良い可能性があります）。レストランのレビュー数によって平均評価にどれほどの信頼を寄せるかが変化することを見ることができます。\n",
        "\n",
        "ある特徴量のより大きな（または小さな）値が別の特徴量の高い信頼性を示すことをモデルに指示する TFL 信頼制約を訓練することができます。これは、特徴量の構成で、`reflects_trust_in` 構成を設定することで実行できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OA14j0erm6TJ"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            # Larger num_reviews indicating more trust in avg_rating.\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        )\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "model_graph = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "puvP9X8XxyRV"
      },
      "source": [
        "次の図は、トレーニング済みの格子関数を示します。信頼制約により、較正済みの `num_reviews` のより大きな値によって、較正済みの `avg_rating` に対してより高い勾配が強制され、格子出力により大きな変化が生じることが期待されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "RounEQebxxnA"
      },
      "outputs": [],
      "source": [
        "lat_mesh_n = 12\n",
        "lat_mesh_x, lat_mesh_y = tfl.test_utils.two_dim_mesh_grid(\n",
        "    lat_mesh_n**2, 0, 0, 1, 1)\n",
        "lat_mesh_fn = tfl.test_utils.get_hypercube_interpolation_fn(\n",
        "    model_graph.output_node.weights.flatten())\n",
        "lat_mesh_z = [\n",
        "    lat_mesh_fn([lat_mesh_x.flatten()[i],\n",
        "                 lat_mesh_y.flatten()[i]]) for i in range(lat_mesh_n**2)\n",
        "]\n",
        "trust_plt = tfl.visualization.plot_outputs(\n",
        "    (lat_mesh_x, lat_mesh_y),\n",
        "    {\"Lattice Lookup\": lat_mesh_z},\n",
        "    figsize=(6, 6),\n",
        ")\n",
        "trust_plt.title(\"Trust\")\n",
        "trust_plt.xlabel(\"Calibrated avg_rating\")\n",
        "trust_plt.ylabel(\"Calibrated num_reviews\")\n",
        "trust_plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SKe3UHX6pUjw"
      },
      "source": [
        "### キャリブレータを平滑化する\n",
        "\n",
        "では、`avg_rating` のキャリブレータを見てみましょう。単調的に上昇してはいますが、勾配の変化は突然起こっており、解釈が困難です。そのため、`regularizer_configs` にレギュラライザーをセットアップして、このキャリブレータを平滑化したいと思います。\n",
        "\n",
        "ここでは、反りの変化を縮減するために `wrinkle` レギュラライザを適用します。また、`laplacian` レギュラライザを使用してキャリブレータを平らにし、`hessian` レギュラライザを使用してより線形にします。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qxFHH3hSpWfq"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "            ],\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "            ],\n",
        "        )\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HHpp4goLvuPi"
      },
      "source": [
        "キャリブレータがスムーズになり、全体的な推定 CTR がグラウンドトゥルースにより一致するように改善されました。これは、テストメトリックと等高線図の両方に反映されます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pSUd6aFlpYz4"
      },
      "source": [
        "### 分類較正の部分単調性\n",
        "\n",
        "これまで、モデルには 2 つの数値特徴量のみを使用してきました。ここでは、分類較正レイヤーを使用した 3 つ目の特徴量を追加します。もう一度、描画とメトリック計算用のヘルパー関数のセットアップから始めます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5tLDKwTmjrLw"
      },
      "outputs": [],
      "source": [
        "def analyze_three_d_estimator(estimator, name):\n",
        "  # Extract validation metrics.\n",
        "  metric = estimator.evaluate(input_fn=val_input_fn)\n",
        "  print(\"Validation AUC: {}\".format(metric[\"auc\"]))\n",
        "  metric = estimator.evaluate(input_fn=test_input_fn)\n",
        "  print(\"Testing AUC: {}\".format(metric[\"auc\"]))\n",
        "\n",
        "  def three_d_pred(avg_ratings, num_reviews, dollar_rating):\n",
        "    results = estimator.predict(\n",
        "        tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "            x=pd.DataFrame({\n",
        "                \"avg_rating\": avg_ratings,\n",
        "                \"num_reviews\": num_reviews,\n",
        "                \"dollar_rating\": dollar_rating,\n",
        "            }),\n",
        "            shuffle=False,\n",
        "        ))\n",
        "    return [x[\"logistic\"][0] for x in results]\n",
        "\n",
        "  figsize(11, 22)\n",
        "  plot_fns([(\"{} Estimated CTR\".format(name), three_d_pred),\n",
        "            (\"CTR\", click_through_rate)],\n",
        "           split_by_dollar=True)\n",
        "  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CnPiqf4rq6kJ"
      },
      "source": [
        "3 つ目の特徴量である `dollar_rating` を追加するには、TFL での分類特徴量の取り扱いは、特徴量カラムと特徴量構成の両方においてわずかに異なることを思い出してください。ここでは、ほかのすべての特徴量が固定されている場合に、\"DD\" レストランの出力が \"D\" よりも大きくなるように、部分単調性を強制します。これは、特徴量構成の `monotonicity` 設定を使用して行います。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m-w7iGEEpgGt"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "    tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "        \"dollar_rating\",\n",
        "        vocabulary_list=[\"D\", \"DD\", \"DDD\", \"DDDD\"],\n",
        "        dtype=tf.string,\n",
        "        default_value=0),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "            ],\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"dollar_rating\",\n",
        "            lattice_size=2,\n",
        "            pwl_calibration_num_keypoints=4,\n",
        "            # Here we only specify one monotonicity:\n",
        "            # `D` resturants has smaller value than `DD` restaurants\n",
        "            monotonicity=[(\"D\", \"DD\")],\n",
        "        ),\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_three_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gdIzhYL79_Pp"
      },
      "source": [
        "この分類キャリブレータは、DD > D > DDD > DDDD というモデル出力の優先を示します。このセットアップではこれらは定数です。欠落する値のカラムもあることに注意してください。このチュートリアルのトレーニングデータとテストデータには欠落した特徴量はありませんが、ダウンストリームでモデルが使用される場合に値の欠落が生じたときには、モデルは欠損値の帰属を提供します。\n",
        "\n",
        "ここでは、`dollar_rating` で条件付けされたモデルの予測 CTR も描画します。必要なすべての制約が各スライスで満たされているところに注意してください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rh0H2b6l_rwZ"
      },
      "source": [
        "### 出力較正"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KPb2ri4e7HXF"
      },
      "source": [
        "ここまでトレーニングしてきたすべての TFL モデルでは、格子レイヤー（モデルグラフで \"Lattice\" と示される部分）はモデル予測を直接出力しますが、格子出力をスケーリングし直してモデル出力を送信すべきかわからないことがたまにあります。\n",
        "\n",
        "- 特徴量が $log$ カウントでラベルがカウントである。\n",
        "- 格子は頂点をほとんど使用しないように構成されているが、ラベル分布は比較的複雑である。\n",
        "\n",
        "こういった場合には、格子出力とモデル出力の間に別のキャリブレータを追加して、モデルの柔軟性を高めることができます。ここでは、今作成したモデルにキーポイントを 5 つ使用したキャリブレータレイヤーを追加することにしましょう。また、出力キャリブレータのレギュラライザも追加して、関数の平滑性を維持します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k5Sg_gUj_0i4"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "    tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "        \"dollar_rating\",\n",
        "        vocabulary_list=[\"D\", \"DD\", \"DDD\", \"DDDD\"],\n",
        "        dtype=tf.string,\n",
        "        default_value=0),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    output_calibration=True,\n",
        "    output_calibration_num_keypoints=5,\n",
        "    regularizer_configs=[\n",
        "        tfl.configs.RegularizerConfig(name=\"output_calib_wrinkle\", l2=0.1),\n",
        "    ],\n",
        "    feature_configs=[\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name=\"num_reviews\",\n",
        "        lattice_size=2,\n",
        "        monotonicity=\"increasing\",\n",
        "        pwl_calibration_convexity=\"concave\",\n",
        "        pwl_calibration_num_keypoints=20,\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "        ],\n",
        "        reflects_trust_in=[\n",
        "            tfl.configs.TrustConfig(\n",
        "                feature_name=\"avg_rating\", trust_type=\"edgeworth\"),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name=\"avg_rating\",\n",
        "        lattice_size=2,\n",
        "        monotonicity=\"increasing\",\n",
        "        pwl_calibration_num_keypoints=20,\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name=\"dollar_rating\",\n",
        "        lattice_size=2,\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        # Here we only specify one monotonicity:\n",
        "        # `D` resturants has smaller value than `DD` restaurants\n",
        "        monotonicity=[(\"D\", \"DD\")],\n",
        "    ),\n",
        "])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_three_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TLOGDrYY0hH7"
      },
      "source": [
        "最後のテストメトリックとプロットは、常識的な制約を使用することで、モデルが予期しない振る舞いを回避して全体的な入力空間の外挿をいかに改善できるかを示します。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "shape_constraints.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
